package com.za.plugin.transfer.form;


import com.za.plugin.pojo.Content;
import com.za.plugin.util.StrUtil;
import java.util.Locale;

/**
 * Rewrites CASE WHEN clauses: every THEN/ELSE branch whose result is a true/false comparison is
 * wrapped as {@code if(<expr>,true,false)}; branches that are plain values are left as-is.
 *
 * <p>Matching is purely textual: it looks for the lower-case keywords {@code " then "},
 * {@code " else "}, {@code " when "} and {@code " end"} surrounded by spaces, tracks parenthesis
 * depth, and treats a branch as a comparison when its text contains {@code '='}.
 */
public class CaseWhenFormTransfer implements FormTransfer {

    /**
     * Reports whether this transfer applies to {@code sql}.
     *
     * @param sql the SQL text to inspect
     * @return {@code false} when the statement contains {@code merge into} (case-insensitively),
     *     {@code true} otherwise
     */
    @Override
    public boolean isSupport(String sql) {
        // Locale.ROOT: locale-sensitive case mapping (e.g. Turkish dotless i) would otherwise make
        // "MERGE INTO" fail to match "merge into".
        return !sql.toLowerCase(Locale.ROOT).contains("merge into");
    }

    /**
     * Applies the CASE WHEN branch rewrite to {@code sql}.
     *
     * @param sql the SQL text to rewrite
     * @return the rewritten SQL
     */
    @Override
    public String transfer(String sql) {
        return addIfFun(sql);
    }

    /**
     * Wraps comparison-valued THEN/ELSE branches in {@code if(...,true,false)}, recursing into
     * branch values that are themselves CASE expressions, and finally replaces {@code " || "}
     * with {@code " or "}.
     *
     * @param sql the SQL text to rewrite
     * @return the rewritten SQL
     */
    private static String addIfFun(String sql) {
        // Make sure " end" before ")" is followed by a space, so the " end " token can be matched.
        sql = sql.replace(" end)", " end ) ");
        if (sql.endsWith(" end")) {
            // The " end" terminator check needs at least one character after the keyword.
            sql = sql + " ";
        }
        if (sql.contains(" then ")) {
            int thenIdx = sql.indexOf(" then ");
            StringBuilder sb = new StringBuilder();
            int left = 0;
            while (thenIdx >= 0) {
                sb.append(sql, left, thenIdx);
                // " then " and " else " are both 6 chars; the branch value starts at the trailing space.
                Content content = getSegmentForCaseWhen(sql, thenIdx + 5);
                if (content == null) {
                    // The branch value could not be delimited (e.g. " then " inside a string literal,
                    // unbalanced parentheses or CASE/END). Keep the remainder verbatim instead of
                    // failing: the former `assert` is disabled in production and led to an NPE.
                    sb.append(sql, thenIdx, sql.length());
                    break;
                }
                String expectStr = content.getExpectStr();
                int end = content.getEnd();
                // " then " or " else ": the keyword plus exactly one trailing space.
                String thenOrElse = sql.substring(thenIdx, thenIdx + 5) + " ";
                if (expectStr.contains(" when ")) {
                    // The value contains a nested CASE WHEN: recurse until every nested branch
                    // has been rewritten.
                    sb.append(thenOrElse).append(addIfFun(expectStr)).append(" ");
                } else if (expectStr.contains("=")) {
                    // Only comparison branches get the if() wrapper. A plain value must not be
                    // wrapped, because it would be turned into 1 (in Dameng/DM, TRUE is 1).
                    sb.append(thenOrElse).append(" if(").append(expectStr).append(",true,false) ");
                } else {
                    sb.append(thenOrElse).append(expectStr).append(" ");
                }
                left = end;
                thenIdx = findStart(sql, end);
                if (thenIdx == -1) {
                    sb.append(sql, end, sql.length());
                }
            }
            sql = sb.toString();
        }
        // NOTE(review): this treats " || " as logical OR; in Oracle/DM "||" is string
        // concatenation, so confirm the source dialect before relying on this.
        return sql.replace(" || ", " or ");
    }

    /**
     * Extracts the value of a THEN/ELSE branch that starts at {@code start}.
     *
     * <p>The value ends at the first {@code " then "}, {@code " else "} or {@code " when "} at
     * parenthesis depth 0. If the text up to that point has more {@code case when} than
     * {@code " end "} occurrences, the value is extended to the next {@code " end "} until the
     * counts balance. Otherwise the value ends at the first standalone {@code " end"} at depth 0.
     *
     * @param sql   the SQL text
     * @param start index of the space right after the THEN/ELSE keyword
     * @return the branch value and the index just past it, or {@code null} if no terminator could
     *     be found or the CASE/END counts never balance
     */
    private static Content getSegmentForCaseWhen(String sql, int start) {
        int depth = 0;
        int len = sql.length();
        for (int j = start; j < len; j++) {
            char c = sql.charAt(j);
            if (c == '(') {
                depth++;
            } else if (c == ')') {
                depth--;
            }
            if (depth == 0 && j + 6 <= len
                    && (sql.startsWith(" then ", j) || sql.startsWith(" else ", j)
                    || sql.startsWith(" when ", j))) {
                String substring = sql.substring(start, j) + " ";
                int endIdx = start;
                while (!checkIsEqual(substring)) {
                    endIdx = sql.indexOf(" end ", endIdx + 3);
                    if (endIdx < 0) {
                        // No further " end " to balance the nested CASE: the segment cannot be
                        // delimited (previously this crashed in substring()).
                        return null;
                    }
                    substring = sql.substring(start, endIdx + 4) + " ";
                }
                // If no extension was needed, the value stops at the keyword found at j;
                // otherwise it runs through the balancing " end".
                endIdx = endIdx == start ? j : endIdx + 4;
                return new Content(sql.substring(start, endIdx), endIdx);
            }
            // " end" only counts as the END keyword when it is not the start of a longer
            // identifier such as "endtime" or "end_date".
            if (depth == 0 && j + 5 <= len && sql.startsWith(" end", j)
                    && !isIdentifierPart(sql.charAt(j + 4))) {
                return new Content(sql.substring(start, j), j);
            }
        }
        return null;
    }

    /**
     * Reports whether {@code c} can continue an SQL identifier (letter, digit or underscore).
     *
     * @param c the character following a candidate keyword
     * @return {@code true} if {@code c} would make the keyword part of a longer identifier
     */
    private static boolean isIdentifierPart(char c) {
        return Character.isLetterOrDigit(c) || c == '_';
    }

    /**
     * Reports whether {@code substring} contains as many {@code "case when"} occurrences as
     * {@code " end "} occurrences, as counted by {@code StrUtil.getCount}.
     *
     * @param substring the candidate branch text
     * @return {@code true} when the two counts are equal
     */
    private static boolean checkIsEqual(String substring) {
        int caseWhenCnt = StrUtil.getCount(substring, "case when");
        int endCnt = StrUtil.getCount(substring, " end ");
        return caseWhenCnt == endCnt;
    }

    /**
     * Finds the next THEN or ELSE keyword at or after {@code start}.
     *
     * @param sql   the SQL text
     * @param start index to start searching from
     * @return the index of the earlier of {@code " then "} / {@code " else "}, or -1 if neither
     *     occurs
     */
    private static int findStart(String sql, int start) {
        int thenIdx = sql.indexOf(" then ", start);
        int elseIdx = sql.indexOf(" else ", start);
        if (thenIdx == -1) {
            return elseIdx;
        }
        if (elseIdx == -1) {
            return thenIdx;
        }
        return Math.min(thenIdx, elseIdx);
    }
}
